{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Single-cell RNA-seq imputation using DeepImpute\n",
    "\n",
    "This tutorial walks through the main functionalities of DeepImpute: loading data, creating and fitting the networks, imputing, visualizing, and scoring."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from deepimpute.multinet import MultiNet\n",
    "import pandas as pd\n",
    "\n",
    "# Load the dataset with pandas (rows are cells, columns are genes)\n",
    "data = pd.read_csv('test.csv', index_col=0)\n",
    "print('Working on {} cells and {} genes'.format(*data.shape))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create a DeepImpute multinet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using default parameters\n",
    "multinet = MultiNet() "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using custom parameters\n",
    "NN_params = {\n",
    "    'learning_rate': 1e-4,\n",
    "    'batch_size': 64,\n",
    "    'max_epochs': 200,\n",
    "    'ncores': 5,\n",
    "    'sub_outputdim': 512,\n",
    "    'architecture': [\n",
    "        {'type': 'dense', 'activation': 'relu', 'neurons': 200},\n",
    "        {'type': 'dropout', 'activation': 'dropout', 'rate': 0.3},\n",
    "    ],\n",
    "}\n",
    "\n",
    "multinet = MultiNet(**NN_params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fit the networks"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using all the data\n",
    "multinet.fit(data,cell_subset=1,minVMR=0.5)"
   ]
  },
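  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `minVMR` argument above reads as a minimum variance-to-mean ratio (VMR) for genes. As a rough illustration of that quantity, and not DeepImpute's actual selection code, the sketch below computes a per-gene VMR with pandas on a hypothetical toy table (`toy` and its gene names are made up) and keeps the genes at or above 0.5. Whether DeepImpute applies the threshold to raw or transformed counts is not covered here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy counts: 4 cells (rows) x 3 genes (columns)\n",
    "toy = pd.DataFrame(\n",
    "    {'geneA': [0, 2, 4, 6], 'geneB': [5, 5, 5, 5], 'geneC': [0, 0, 0, 10]},\n",
    "    index=['cell1', 'cell2', 'cell3', 'cell4'],\n",
    ")\n",
    "\n",
    "# Per-gene variance-to-mean ratio (pandas uses the sample variance by default)\n",
    "vmr = toy.var() / toy.mean()\n",
    "\n",
    "# Keep genes whose VMR is at least the threshold (0.5 here, an illustrative value)\n",
    "selected = vmr[vmr >= 0.5].index.tolist()\n",
    "print(vmr)\n",
    "print('Genes kept:', selected)"
   ]
  },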
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using 50% of the cells\n",
    "multinet.fit(data,cell_subset=0.5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using 200 cells (randomly selected)\n",
    "multinet.fit(data,cell_subset=200)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Custom fit: train on a hand-picked slice of cells (rows 100 to 249)\n",
    "trainingData = data.iloc[100:250,:]\n",
    "multinet.fit(trainingData)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imputation\n",
    "\n",
    "Imputation can be run on any dataset, as long as its gene labels match those of the training data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "imputedData = multinet.predict(data)"
   ]
  },
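  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a new dataset lists its genes in a different order, or lacks some of them, one option is to reorder its columns to match the training data before calling `predict`. The sketch below uses only pandas on hypothetical toy frames (`train_toy` and `new_toy` are made up). Filling absent genes with zeros is an assumption made for illustration, not something DeepImpute prescribes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy frames standing in for the training data and a new dataset\n",
    "train_toy = pd.DataFrame([[1, 0, 3]], columns=['geneA', 'geneB', 'geneC'])\n",
    "new_toy = pd.DataFrame([[2, 5]], columns=['geneC', 'geneA'])\n",
    "\n",
    "# Reorder columns to the training gene order; absent genes are filled with 0 (assumption)\n",
    "aligned = new_toy.reindex(columns=train_toy.columns, fill_value=0)\n",
    "\n",
    "# The gene labels now match, in the same order\n",
    "assert list(aligned.columns) == list(train_toy.columns)\n",
    "aligned"
   ]
  },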
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Visualization"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "limits = [0, 100]\n",
    "\n",
    "fig, ax = plt.subplots()\n",
    "\n",
    "# Add Gaussian jitter to the raw values to make the point density easier to see\n",
    "jitter = np.random.normal(0, 1, data.size)\n",
    "ax.scatter(data.values.flatten() + jitter, imputedData.values.flatten(), s=2)\n",
    "# Identity line: points on it have equal raw and imputed values\n",
    "ax.plot(limits, limits, 'r-.', linewidth=2)\n",
    "ax.set_xlim(limits)\n",
    "ax.set_ylim(limits)\n",
    "ax.set_xlabel('Raw values (jittered)')\n",
    "ax.set_ylabel('Imputed values')\n",
    "\n",
    "plt.show()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Scoring\n",
    "\n",
    "Display the metrics recorded during training: MSE and Pearson's correlation, both computed on the test data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "multinet.test_metrics"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
